import csv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import visdom
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
from sklearn import datasets as sk_ds
from sklearn.model_selection import train_test_split
import argparse
from models import SimpleMLP, SimpleCNN, LeNet, L_LeNet, C_LeNet
from BCD import GradBCD
from DP_BCD import DPGradBCD
from data import data_generator
from fastDP import PrivacyEngine



# Select the compute device: CUDA when available, otherwise CPU.
# Apple's MPS backend is deliberately not used (the original author noted
# "encounter some bugs" with it), so there is no need to probe
# ``torch.backends.mps`` -- an attribute that older PyTorch builds do not
# provide and that would raise AttributeError on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



# MLP parameters
def initialize_weights(model):
    """Re-initialise every ``nn.Linear`` layer found in ``model`` in place.

    Weights are drawn from a normal distribution with mean 0.01 and standard
    deviation 0.1; biases (when the layer has one) are set to 1. Modules that
    are not ``nn.Linear`` are left untouched.
    """
    for layer in model.modules():
        if not isinstance(layer, nn.Linear):
            continue
        nn.init.normal_(layer.weight.data, mean=0.01, std=0.1)
        bias = layer.bias
        if bias is not None:
            nn.init.ones_(bias.data)


# def initialize_weights(model):
#
#     for module in model.modules():
#         if isinstance(module, nn.Linear):
#             torch.nn.init.ones_(module.weight.data)
#             if module.bias is not None:
#                 torch.nn.init.ones_(module.bias.data)


def _pad_to_32x32(images):
    """Reshape a batch to (N, 1, 28, 28) and zero-pad 2 pixels per side (28x28 -> 32x32)."""
    # Use the batch's real size: the last batch of an epoch may be smaller
    # than args.batch_size, and test loaders may use any batch size.
    return nn.ZeroPad2d(2)(images.reshape(images.size(0), 1, 28, 28))


def train(train_loader, test_loader, device, model, criterion, optimizer, args, num_classes=None):
    """Train ``model`` for ``args.epochs`` epochs, evaluating on ``test_loader`` after each.

    The update rule is selected by ``args.optimizer``:

    * ``"sgd"`` / ``"dp_sgd"``: forward and backward pass against one-hot
      targets, followed by ``optimizer.step()``.
    * ``"sbcd"`` / ``"dp_sbcd"``: ``optimizer.step(closure)`` where the closure
      returns ``(images, one_hot_labels)``, followed by a scheduler step.
    * any other value: forward and backward pass against the integer labels.

    After every epoch the test accuracy is printed; every 10th evaluation the
    accumulated history is written to ``args.savepath + '.npy'``.

    Args:
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        device: device the batches are moved to.
        model: module being trained.
        criterion: loss callable used by the gradient-based branches.
        optimizer: optimizer matching ``args.optimizer``.
        args: namespace providing ``epochs``, ``model``, ``dataset``,
            ``optimizer`` and ``savepath``.
        num_classes: width of the one-hot targets. When omitted, the
            module-level ``output_dim`` (assigned in the ``__main__`` section
            from ``data_generator``) is used, as before.

    Returns:
        The list of test accuracies (fraction correct), one entry per epoch.
    """
    if num_classes is None:
        # Backward-compatible fallback to the script-level global.
        num_classes = output_dim

    n_epochs = args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-2)
    test_acc_record = []

    is_cnn = args.model in ("lenet", "l_lenet", "cifar_lenet")
    is_mnist_like = args.dataset in ("mnist", "fashion_mnist")
    uses_sgd = args.optimizer in ("sgd", "dp_sgd")
    uses_bcd = args.optimizer in ("sbcd", "dp_sbcd")

    for epoch in range(n_epochs):
        model.train()
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            one_hot_labels = torch.zeros(images.size(0), num_classes, device=device)
            one_hot_labels.scatter_(1, labels.unsqueeze(1), 1)

            if uses_sgd:
                # CIFAR-10 batches are passed through unchanged, consistent
                # with the BCD branch below; everything else keeps the
                # original 28x28 -> 32x32 padding for CNN models.
                if is_cnn and args.dataset != "cifar10":
                    images = _pad_to_32x32(images)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, one_hot_labels.float())
                loss.backward()
                optimizer.step()

            elif uses_bcd:
                if args.model in ("lenet", "l_lenet") and is_mnist_like:
                    images = _pad_to_32x32(images)
                elif args.dataset != "cifar10":
                    # Non-CIFAR inputs are flattened to (N, features).
                    images = images.reshape(images.size(0), -1).to(torch.float)

                def closure():
                    return images, one_hot_labels

                optimizer.step(closure)
                scheduler.step()

            else:
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            if args.optimizer == "sgd":
                print(f'Epoch [{epoch + 1}/{n_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

        # Evaluate after every epoch.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device).to(torch.float)
                labels = labels.to(device)
                if is_cnn and is_mnist_like:
                    images = _pad_to_32x32(images)

                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_acc_record.append(correct / total)
        eval_count = epoch + 1
        print(f'{eval_count}: test acc  {100 * correct / total:.2f}')
        if eval_count % 10 == 0:
            np.save(args.savepath + '.npy', test_acc_record)

    return test_acc_record





if __name__ == '__main__':
    # Script entry point: parse the CLI, build the data loaders, model, loss
    # and optimizer, then hand everything to train().
    parser = argparse.ArgumentParser(description='Train i-ResNet/ResNet on Cifar')
    parser.add_argument('--vis_port', default=8097, type=int, help="port for visdom")
    parser.add_argument('--vis_server', default="localhost", type=str, help="server for visdom")
    parser.add_argument('--dataset', type=str, default='mnist',
                        choices=['mnist', 'synthetic','purchase100', 'adults','mnist-1k','cifar10','fashion_mnist'], help='dataset name')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'sbcd', 'dp_sbcd','dp_sgd'])
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--rho', type=float, default=1.0)
    parser.add_argument('--gamma', type=float, default=1.0)
    parser.add_argument('--lips', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--shift', type=int, default=0)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--ns', type=float, default=0)
    parser.add_argument('--ne', type=float, default=-1)
    parser.add_argument('--batch_num', type=int, default=1)
    parser.add_argument('--model', type=str, default="mlp")
    # Only these two losses get a criterion below; anything else previously
    # left `criterion` undefined and crashed at the train() call.
    parser.add_argument('--loss', type=str, default="mse", choices=['mse', 'ce'])
    parser.add_argument('--savepath', type=str, default="res")
    parser.add_argument('--xd', type=float, default=1.0)

    args = parser.parse_args()
    if args.optimizer == "sbcd":
        # No optimizer is constructed for plain "sbcd" anywhere in this
        # script (GradBCD is imported but never instantiated), so fail early
        # with a clear message instead of a NameError after data loading.
        parser.error("optimizer 'sbcd' is not wired up in this script; "
                     "use 'sgd', 'dp_sgd' or 'dp_sbcd'")
    # viz = visdom.Visdom(port=args.vis_port, server="http://" + args.vis_server)
    # assert viz.check_connection(), "Could not make visdom"

    train_l, test_l, input_dim, output_dim = data_generator(args, device)

    if train_l is None or test_l is None:
        # Nothing to train on: stop here rather than crash inside train().
        raise SystemExit('Please provide train and test datasets')

    layer = 0
    conv = -1
    is_mnist_like = args.dataset in ("mnist", "fashion_mnist")

    # Image shape forwarded to DPGradBCD; stays None for non-image datasets.
    data_input_dim = None
    if args.dataset == "cifar10":
        data_input_dim = (3, 32, 32)
    if is_mnist_like:
        data_input_dim = (1, 32, 32)

    arch = "large" if is_mnist_like else "small"

    # Model selection. Note the precedence: "mlp" wins first, then any
    # non-MLP run on CIFAR-10 uses C_LeNet regardless of --model.
    if args.model == "mlp":
        model = SimpleMLP(input_dim, output_dim, arch=arch).to(device)
        layer = 4
    elif args.dataset == "cifar10":
        model = C_LeNet().to(device)
        layer = 5
        conv = 2
    elif args.model == "lenet":
        model = LeNet().to(device)
        layer = 5
        conv = 2
    elif args.model == "l_lenet":
        model = L_LeNet().to(device)
        layer = 5
        conv = 2
    else:
        model = SimpleCNN(input_dim, output_dim, arch=arch).to(device)
        layer = 4
    initialize_weights(model)

    # Loss function (argparse restricts --loss to these two values).
    if args.loss == "mse":
        criterion = nn.MSELoss().to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # alpha = torch.ones(layer + 1) * args.alpha
    if args.optimizer == "dp_sbcd":
        # Autograd is switched off for every parameter before they are
        # handed to DPGradBCD.
        for param in model.parameters():
            param.requires_grad = False

        params_dict = [{'params': param, 'name': name} for name, param in model.named_parameters()]

        optimizer = DPGradBCD(params_dict, conv=conv, layers=layer, lossf=args.loss,rho=args.rho,gamma=args.gamma,
                              batch_size=args.batch_size, device=device, alpha=args.alpha, batch_num=args.batch_num,
                              epochs=args.epochs, lips=args.lips, lr=args.lr, ns=args.ns, ne=args.ne, input_dim=data_input_dim
        )

    if args.optimizer in ("dp_sgd", "sgd"):
        # For the SGD variants the Linear layers are re-initialised,
        # overriding initialize_weights() above.
        def init_weights(m):
            if isinstance(m, nn.Linear):
                # xavier_uniform_ is the in-place, non-deprecated spelling.
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.01)
        model.apply(init_weights)

        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)
        if args.optimizer == "dp_sgd":
            # NOTE(review): the engine is constructed but never attached to
            # `optimizer`; fastDP's documented usage calls
            # `privacy_engine.attach(optimizer)`. Confirm whether noise is
            # actually being injected in this configuration.
            privacy_engine = PrivacyEngine(
                model,
                batch_size=args.batch_size,
                sample_size=args.batch_size * args.batch_num,
                epochs=args.epochs,
                target_epsilon=0.1566,
                target_delta=0.00001,
                data_loader=train_l,
                clipping_fn='automatic',
                clipping_mode='MixOpt',
                origin_params=None,
                clipping_style='all-layer',
            )

    train(train_l, test_l, device, model, criterion, optimizer, args)
